"""
Demographic polynomial construction following Fair & Dominguez (1991),
Higgins (1998), and Koomen & Wicht (2020).

Processes UN WPP data into:
1. 17 age-group population shares (0-4, 5-9, ..., 75-79, 80+)
2. Demeaned shares: deviations from a world average, GDP-weighted when GDP
   data is supplied and otherwise the unweighted cross-country mean
3. Polynomial-transformed variables Z_1, ..., Z_P (cubic, P=3, by default)
   for regression
4. Recovery of the implied age-group coefficients α_g from the estimated γ's
5. Dependency ratios and the old-age dependency ratio at t+horizon
"""

import pandas as pd
import numpy as np
from pathlib import Path

RAW_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/raw")
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/processed")

# 17 standard age groups: sixteen closed 5-year bands (0-4 ... 75-79)
# followed by a single open-ended 80+ band.
AGE_GROUPS = [f'{start}-{start + 4}' for start in range(0, 80, 5)] + ['80+']
G = len(AGE_GROUPS)  # = 17


def process_un_wpp(raw_df=None):
    """
    Process UN WPP 2024 data into country-year age-group shares.

    The WPP2024 bulk CSV has columns:
        ISO3_code, LocTypeName, Location, Time, AgeGrp, AgeGrpStart, PopTotal, ...
    with 21 age groups (0-4 through 95-99, 100+).

    We collapse into 17 standard groups (0-4 through 75-79, 80+)
    by aggregating 80-84, 85-89, 90-94, 95-99, 100+ into 80+.

    Parameters
    ----------
    raw_df : DataFrame, optional
        Pre-loaded WPP table. If None, it is read from
        RAW_DIR / "un_wpp_population_by_age.csv". Never modified.

    Returns
    -------
    DataFrame with columns:
        iso3, iso_numeric, country, year, total_pop, n_1, ..., n_17
    where n_g = population in age group g / total population and total_pop
    is in the units of the PopTotal column. iso_numeric is an all-NaN
    placeholder. country is NaN when the input has no Location column.
    The following country-years are dropped:
    - those with a missing age group;
    - those whose shares sum outside [0.99, 1.01].
    The result is also written to PROCESSED_DIR / "demographic_shares.csv".

    Raises
    ------
    FileNotFoundError
        If raw_df is None and the raw CSV does not exist.
    """
    if raw_df is None:
        fpath = RAW_DIR / "un_wpp_population_by_age.csv"
        if not fpath.exists():
            raise FileNotFoundError(f"UN WPP data not found at {fpath}. Run download.py first.")
        raw_df = pd.read_csv(fpath, low_memory=False)

    print(f"Processing UN WPP data: {raw_df.shape}")

    df = raw_df

    # Filter to countries only (exclude regions, world, etc.)
    if 'LocTypeName' in df.columns:
        df = df[df['LocTypeName'] == 'Country/Area']
        print(f"  After filtering to countries: {len(df)} rows")

    # Take our own copy after filtering. The column assignments below must
    # never touch the caller's raw_df and must not act on a boolean-indexed
    # slice, which would trigger chained-assignment warnings.
    df = df.copy()

    # Map WPP age labels to group indices 1..G. The first G-1 labels map
    # one-to-one. Every WPP band from 80-84 upward maps to the open-ended
    # last group (80+).
    age_to_group = {label: g for g, label in enumerate(AGE_GROUPS[:-1], start=1)}
    age_to_group.update(dict.fromkeys(['80-84', '85-89', '90-94', '95-99', '100+'], G))

    df['age_group_idx'] = df['AgeGrp'].map(age_to_group)
    df = df.dropna(subset=['age_group_idx', 'ISO3_code'])
    df['age_group_idx'] = df['age_group_idx'].astype(int)
    df['PopTotal'] = pd.to_numeric(df['PopTotal'], errors='coerce')

    print(f"  After age mapping: {len(df)} rows, {df['ISO3_code'].nunique()} countries")

    # Sum sub-groups, collapsing 80-84 ... 100+ into group G, and pivot to one
    # row per country-year. reindex guarantees all G columns exist even when
    # an age group is absent from the input. Such cells become NaN, and the
    # affected country-years are dropped below instead of raising KeyError.
    wide = (
        df.groupby(['ISO3_code', 'Time', 'age_group_idx'])['PopTotal']
        .sum()
        .unstack('age_group_idx')
        .reindex(columns=range(1, G + 1))
    )
    pop_cols = [f'pop_{g}' for g in range(1, G + 1)]
    wide.columns = pop_cols
    pivot = wide.reset_index().rename(columns={'ISO3_code': 'iso3', 'Time': 'year'})

    # Compute total population and shares
    pivot['total_pop'] = pivot[pop_cols].sum(axis=1)

    share_cols = [f'n_{g}' for g in range(1, G + 1)]
    for g in range(1, G + 1):
        pivot[f'n_{g}'] = pivot[f'pop_{g}'] / pivot['total_pop']

    # Attach the WPP location name. Deduplicate on the code so that an ISO3
    # code carrying more than one Location spelling cannot duplicate rows.
    if 'Location' in df.columns:
        country_names = (
            df[['ISO3_code', 'Location']]
            .drop_duplicates(subset='ISO3_code')
            .rename(columns={'ISO3_code': 'iso3', 'Location': 'country'})
        )
        pivot = pivot.merge(country_names, on='iso3', how='left')
    else:
        pivot['country'] = np.nan
    # Placeholder column: numeric ISO codes are not populated by this function.
    pivot['iso_numeric'] = np.nan

    # Sanity check: shares should sum to ~1
    share_sum = pivot[share_cols].sum(axis=1)
    bad = (share_sum < 0.99) | (share_sum > 1.01)
    if bad.any():
        print(f"  Warning: {bad.sum()} rows have share sums outside [0.99, 1.01]")
        pivot = pivot[~bad]

    result = pivot[['iso3', 'iso_numeric', 'country', 'year', 'total_pop'] + share_cols].copy()
    # Country-years with a missing age group have NaN shares; drop them.
    result = result.dropna(subset=share_cols)

    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    result.to_csv(PROCESSED_DIR / "demographic_shares.csv", index=False)
    print(f"  Saved demographic shares: {result.shape}")
    print(f"  Countries: {result['iso3'].nunique()}, Years: {result['year'].min()}-{result['year'].max()}")
    return result


# ---------------------------------------------------------------------------
# Polynomial construction
# ---------------------------------------------------------------------------

def construct_polynomial_variables(shares_df, gdp_df=None, P=3):
    """
    Construct Fair-Dominguez polynomial demographic variables.

    Following Koomen & Wicht (2020):
    1. Age-group shares n_git for G=17 groups
    2. Demean shares relative to a world average. The average is GDP-weighted
       when gdp_df is given, and the unweighted cross-country mean per year
       otherwise.
    3. Impose polynomial constraint: α_g = Σ_{p=0}^{P} γ_p * g^p
    4. Impose zero-sum constraint: Σ α_g = 0
    5. Construct Z_p = Σ_g g^p * dn_git for p=1,...,P, where dn_git are the
       demeaned shares

    Parameters
    ----------
    shares_df : DataFrame with iso3, year, n_1,...,n_17
    gdp_df : DataFrame with iso3, year, gdp (for weighting). If None, uses equal weights.
    P : int, polynomial degree (default 3 for cubic); Z_1..Z_P are produced

    Returns
    -------
    DataFrame (same index as shares_df) with iso3, year, Z_1, ..., Z_P,
    followed by the demeaned shares d_n_1, ..., d_n_17. The result is also
    written to PROCESSED_DIR / "demographic_polynomials.csv".

    Raises
    ------
    ValueError
        If any of n_1..n_17 is missing from shares_df.
    """
    share_cols = [f'n_{g}' for g in range(1, G + 1)]
    d_share_cols = [f'd_{col}' for col in share_cols]

    # Verify we have all share columns
    missing = [c for c in share_cols if c not in shares_df.columns]
    if missing:
        raise ValueError(f"Missing share columns: {missing}")

    df = shares_df[['iso3', 'year'] + share_cols].copy()

    # --- Step 1: demean shares relative to a world average ---
    if gdp_df is not None:
        df = _demean_shares(df, gdp_df, share_cols)
    else:
        # Simple demeaning: subtract the unweighted cross-sectional mean per year.
        year_means = df.groupby('year')[share_cols].transform('mean')
        for col in share_cols:
            df[f'd_{col}'] = df[col] - year_means[col]

    # --- Step 2: polynomial basis Z_p,it = Σ_{g=1}^{G} g^p * dn_git ---
    # Assuming each country-year's shares sum to one, the demeaned shares sum
    # to zero across g. The γ_0 term therefore multiplies Σ_g dn_git = 0 and
    # drops out of the regression, so only Z_1..Z_P are built here. γ_0 is
    # pinned down afterwards by the zero-sum constraint Σ_g α_g = 0
    # (see recover_age_coefficients).
    g_indices = np.arange(1, G + 1)  # [1, 2, ..., 17]
    dn = df[d_share_cols].to_numpy(dtype=float)
    z_cols = [f'Z_{p}' for p in range(1, P + 1)]
    for p, z_col in enumerate(z_cols, start=1):
        df[z_col] = dn @ (g_indices ** p)

    # Keep the demeaned shares alongside Z for later analysis.
    result = df[['iso3', 'year'] + z_cols + d_share_cols].copy()

    # The output directory may not exist if process_un_wpp was not run first.
    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    result.to_csv(PROCESSED_DIR / "demographic_polynomials.csv", index=False)
    print(f"  Saved polynomial variables: {result.shape}")
    return result


def _demean_shares(df, gdp_df, share_cols):
    """Demean age-group shares relative to GDP-weighted world average."""
    # Merge GDP
    merged = df.merge(
        gdp_df[['iso3', 'year', 'gdp']].dropna(),
        on=['iso3', 'year'],
        how='left'
    )

    # For projection years beyond WEO coverage, forward-fill last available GDP
    # so that GDP-weighted demeaning is consistent across historical and projected
    merged = merged.sort_values(['iso3', 'year'])
    merged['gdp'] = merged.groupby('iso3')['gdp'].ffill()
    merged['gdp'] = merged['gdp'].fillna(0)

    # Compute GDP-weighted world average shares per year
    for col in share_cols:
        weighted = merged.groupby('year').apply(
            lambda x: np.average(x[col], weights=x['gdp']) if x['gdp'].sum() > 0
            else x[col].mean(),
            include_groups=False
        ).rename(f'world_{col}')
        merged = merged.merge(weighted.reset_index(), on='year', how='left')
        merged[f'd_{col}'] = merged[col] - merged[f'world_{col}']

    # Copy back to df
    for col in share_cols:
        df[f'd_{col}'] = merged[f'd_{col}'].values

    return df


def recover_age_coefficients(gamma_hat, P=3):
    """
    Map estimated polynomial slopes γ_1..γ_P back to age-group effects α_g.

    Each group's coefficient is the polynomial
        α_g = γ_0 + Σ_{p=1}^{P} γ_p * g^p,    g = 1, ..., G
    The intercept γ_0 is not estimated. It is chosen so that the effects sum
    to zero across groups:
        γ_0 = -(1/G) * Σ_{p=1}^{P} γ_p * Σ_g g^p

    Parameters
    ----------
    gamma_hat : indexable sequence holding γ_1, ..., γ_P at positions 0..P-1
        (only the first P entries are read)
    P : polynomial degree

    Returns
    -------
    alpha : numpy array of length G, one implied coefficient per age group
    """
    ages = np.arange(1, G + 1)
    slopes = [gamma_hat[k] for k in range(P)]
    age_powers = [ages ** (k + 1) for k in range(P)]

    # Intercept implied by the zero-sum restriction Σ_g α_g = 0.
    intercept = -sum(slope * np.sum(powers) for slope, powers in zip(slopes, age_powers)) / G

    alpha = np.full(G, intercept)
    for slope, powers in zip(slopes, age_powers):
        alpha += slope * powers
    return alpha


def compute_dependency_ratios(shares_df):
    """
    Compute standard demographic dependency ratios from age-group shares.

    Group mapping (1-indexed 5-year groups):
    - youth 0-14 = n_1..n_3
    - working age 15-64 = n_4..n_13
    - old 65+ = n_14..n_17

    Returns DataFrame (same index as shares_df) with columns:
    - iso3, year
    - youth_dep: population 0-14 / population 15-64
    - old_dep (OADR): population 65+ / population 15-64
    - total_dep: (0-14 + 65+) / 15-64
    - working_age_share: population 15-64 / total (the sum of n_4..n_13)
    """
    # Youth (0-14): groups 1-3 (0-4, 5-9, 10-14)
    youth = shares_df['n_1'] + shares_df['n_2'] + shares_df['n_3']

    # Working age (15-64): groups 4-13 (15-19 through 60-64)
    working_age = shares_df[[f'n_{g}' for g in range(4, 14)]].sum(axis=1)

    # Old (65+): groups 14-17 (65-69, 70-74, 75-79, 80+)
    old = shares_df[[f'n_{g}' for g in range(14, 18)]].sum(axis=1)

    # Build only the output columns rather than copying the whole input frame.
    result = shares_df[['iso3', 'year']].copy()
    result['youth_dep'] = youth / working_age
    result['old_dep'] = old / working_age
    result['total_dep'] = (youth + old) / working_age
    result['working_age_share'] = working_age
    return result


def compute_future_oadr(shares_df, horizon=20):
    """
    Compute the old-age dependency ratio (OADR) at t+horizon for each row.

    OADR = (n_14 + ... + n_17) / (n_4 + ... + n_13), i.e. 65+ over 15-64.
    For each (iso3, year) row, the value is looked up at (iso3, year + horizon)
    within shares_df itself. Rows whose t+horizon year is absent get NaN.
    Whether the future values are estimates or projections depends entirely on
    which years shares_df contains. NOTE(review): confirm that the input holds
    UN medium-variant projections where future years are needed.
    This is used as a control variable in the EBA model.

    Parameters
    ----------
    shares_df : DataFrame with iso3, year, n_4..n_17. Not modified.
    horizon : int, number of years ahead (default 20)

    Returns
    -------
    DataFrame with columns iso3, year, oadr_plus{horizon}, in shares_df's row
    order with a fresh RangeIndex. There is one row per input row, assuming
    (iso3, year) is unique.
    """
    old_cols = [f'n_{g}' for g in range(14, 18)]
    wa_cols = [f'n_{g}' for g in range(4, 14)]
    # Computed as a local Series so the caller's frame is not mutated.
    oadr = shares_df[old_cols].sum(axis=1) / shares_df[wa_cols].sum(axis=1)

    # Re-key each OADR value to year - horizon. A plain merge on (iso3, year)
    # then pairs row t with the value observed at t + horizon.
    future = pd.DataFrame({
        'iso3': shares_df['iso3'].to_numpy(),
        'year': (shares_df['year'] - horizon).to_numpy(),
        f'oadr_plus{horizon}': oadr.to_numpy(),
    })

    return shares_df[['iso3', 'year']].merge(future, on=['iso3', 'year'], how='left')


if __name__ == "__main__":
    # Smoke run: build the age-share table from the raw WPP file and preview it.
    shares_df = process_un_wpp()
    print(f"\nShares sample:\n{shares_df.head()}")

    # Polynomial regressors with unweighted demeaning (no GDP data passed).
    poly_df = construct_polynomial_variables(shares_df)
    print(f"\nPolynomial variables sample:\n{poly_df.head()}")
